import torch
import torch.nn as nn
from torch.nn import Dropout
import torch.nn.functional as F
from torch.nn import Linear
from torch.nn.parameter import Parameter
from math import ceil, floor, sqrt

from functools import partial

import numpy as np


# Works only when Ta <= Tb.
# Maps Ta*Da -> Ta*k -> Ta*Tb, directional.
# Handles both aligned and unaligned inputs: just pass in the two sequences.
class MyCDC(nn.Module):
  """Context-dependent dynamic convolution over two (possibly unaligned) sequences.

  Each position of the query sequence x predicts a per-head convolution
  kernel from its own features (Ta*Da -> Ta*K).  The kernels are scattered
  into a banded Ta x Tb matrix via strided copies and applied to the key
  sequence y with a batched matmul (Ta*K -> Ta*Tb), so the module handles
  both aligned and unaligned sequence pairs.  Requires x_len <= y_len.
  """

  def __init__(self,
               kernel_size,
               x_size,
               y_size,
               x_len,
               y_len,
               num_heads=1,
               weight_dropout=0.,
               bias=True):
    """Build the projections and precompute the band-expansion parameters.

    Args:
      kernel_size: base kernel size; an even amount is added so neighbouring
        blocks overlap enough when y_len > x_len.
      x_size: feature size of the query sequence.
      y_size: feature size of the key sequence (must be divisible by
        num_heads, since keys are split per head in forward()).
      x_len: length Ta of the query sequence.
      y_len: length Tb of the key sequence (Ta <= Tb).
      num_heads: number of convolution heads.
      weight_dropout: dropout probability applied to the predicted kernels.
      bias: whether the linear projections use a bias term.

    Raises:
      ValueError: if no band-expansion parameters exist for
        (x_len, y_len, effective kernel size).
    """
    super(MyCDC, self).__init__()
    # Enlarge the kernel by an even base so there is enough overlap between
    # blocks when Tb > Ta (even, so the parity of kernel_size is preserved).
    ks_base = floor(y_len / x_len) // 2 * 2
    self.kernel_size = kernel_size + ks_base

    # NOTE(review): `padding` is not referenced anywhere else in this file;
    # kept because external code may read the attribute.
    self.padding = self.kernel_size // 2 if self.kernel_size % 2 == 1 else (
        (self.kernel_size - 1) // 2, self.kernel_size // 2)

    self.num_heads = num_heads
    self.weight_dropout = Dropout(weight_dropout, inplace=False)

    self.x_size = x_size
    self.y_size = y_size
    self.x_len = x_len
    self.y_len = y_len

    # Predicted kernels are normalized with a softmax over the kernel dim.
    self.weight_act = partial(F.softmax, dim=-1)
    self.weight_transform = Linear(
        self.x_size, self.num_heads * self.kernel_size, bias=bias)
    self.output_transform = Linear(self.y_size, self.x_size, bias=bias)

    # Context-independent gate multiplied onto the softmaxed kernels.
    # (Per the original author: softmaxing this instead of the
    # context-predicted weight was tested and found harmful.)
    self.learned_weight = Parameter(
        torch.Tensor(self.num_heads, self.kernel_size))
    nn.init.kaiming_uniform_(self.learned_weight, a=sqrt(5))

    self.expandParamsCal()

  def forward(self, query, key, value=None, attn_mask=None):
    """Apply the dynamic convolution.

    Args:
      query: [Tx, B, Cx] query sequence (Tx must equal x_len).
      key: [Ty, B, Cy] key sequence (Ty must equal y_len).
      value, attn_mask: unused; presumably kept so the call signature
        matches an attention-style interface — TODO confirm with callers.

    Returns:
      (output, None) where output is [Tx, B, Cx]; the second element is a
      placeholder where an attention module would return weights.
    """
    K, H = self.kernel_size, self.num_heads
    Tx, B, Cx = query.size()
    Ty, _, Cy = key.size()

    # Predict one kernel per position and head: Tx,B,Cx -> Tx,B,H*K
    weight = self.weight_transform(query)

    weight = self.weight_dropout(weight)

    # Softmax-normalize each kernel: Tx,B,H*K -> Tx*B,H,K
    weight = weight.view(Tx * B, H, K)
    weight = self.weight_act(weight)

    # Gate with the learned (context-independent) per-head kernel.
    weight = weight * self.learned_weight
    # Tx,B,H*K -> B*H,Tx,K
    weight = weight.view(Tx, B * H, K).transpose(0, 1)

    # Scatter the kernels into a band matrix: B*H,Tx,K -> B*H,Tx,Ty
    weight_expanded = self.expandUnaligned(weight, Tx, Ty, B, K, H)

    # Apply to the per-head key features and project back to Cx.
    key = key.view(Ty, B * H, -1).transpose(0, 1)
    output = torch.bmm(weight_expanded, key)
    output = output.transpose(0, 1).contiguous().view(Tx, B, Cy)
    output = self.output_transform(output)

    return output, None  # None: placeholder for attention-style weights

  def expandParamsCal(self):
    """Search band-layout parameters that tile Ta kernels onto a Ta x Tb matrix.

    Finds step sizes (step_x: shift along y per query position within a
    block; step_y: shift between blocks), a block count, a block size, an
    x-axis padding and a y-axis offset such that the K-sized kernels cover
    the (padded) matrix exactly.  Results are cached on `self` and consumed
    by `expandUnaligned`.

    Raises:
      ValueError: if no consistent parameter set exists for (Ta, Tb, K).
    """
    Ta = self.x_len
    Tb = self.y_len
    K = self.kernel_size
    H = self.num_heads

    # Try step_x values closest to the "ideal" slope of the band first.
    step_x_round = round((Tb - K) / (Ta - 1))
    step_x_lists = sorted(
        list(range(0, K)), key=lambda key: abs(key - step_x_round))
    self.good_flag = False
    for step_x in step_x_lists:
      # there must be overlap between blocks
      step_y_lists = sorted(
          list(range(-step_x, K - step_x)), key=lambda key: abs(key))
      for step_y in step_y_lists:
        # at most cut out the length of a kernel
        for out_offset in range(0, K):
          # step_y != 0 means there must be more than 1 block
          if step_y != 0:
            n_blk, res = divmod(Tb + out_offset - K - step_x * (Ta - 1), step_y)
            self.good_good_flag = (res == 0 and n_blk > 0)
          else:
            n_blk, res = 0, 0
            self.good_good_flag = (K + step_x * (Ta - 1) == Tb + out_offset)
          if self.good_good_flag:
            # notice that the n_blk above is actually n_blk-1
            # assume that no whole blk is truncated
            n_blk = abs(n_blk) + 1
            # block size along x
            blk_szx = ceil(Ta / n_blk)
            # padded total size along x
            Tx = blk_szx * n_blk
            # extra x length introduced by rounding Ta up to whole blocks
            del_x = Tx - Ta

            # padded total size along y
            Ty = Tb + step_x * del_x + out_offset
            # split the x padding between top and bottom
            pad_up = del_x // 2
            pad_down = del_x - pad_up
            # cut position along y
            pad_left = pad_up * step_x + out_offset // 2
            # the padding must not swallow a whole block
            if pad_up < blk_szx and pad_down < blk_szx:
              self.good_flag = True

          if self.good_flag:
            break
        if self.good_flag:
          break
      if self.good_flag:
        break

    # Fail fast.  The original code only printed a warning here and then
    # crashed just below with a NameError on the unbound band parameters.
    if not self.good_flag:
      raise ValueError(
          f'Ta={self.x_len},Tb={self.y_len},K={self.kernel_size},no solution')

    # Cache everything expandUnaligned needs; element 0 of the two shape
    # lists is filled in with B*H at forward time.
    self.weight_padding_list = (0, 0, pad_up, pad_down)
    self.weight_padded_view_shape = [-1, n_blk, blk_szx, K]
    self.weight_expand_shape = [-1, Tx, Ty]
    # Strides (in elements) of the as_strided view used to scatter the
    # kernels into the padded Tx x Ty matrix.
    self.weight_expand_as_strided_steps = (
        Tx * Ty,
        (Ty + step_x) * blk_szx + step_y,
        Ty + step_x,
        1,
    )
    self.pad_left = pad_left
    self.pad_up = pad_up

    self.step_x = step_x
    self.step_y = step_y
    self.n_blk = n_blk
    self.blk_szx = blk_szx
    self.out_offset = out_offset

  def expandUnaligned(self, weight, Ta, Tb, B, K, H):
    """Scatter per-position kernels into a [B*H, Ta, Tb] band matrix.

    Args:
      weight: [B*H, Ta, K] normalized kernels.
      Ta, Tb: query/key lengths; B: batch size; K: kernel size; H: heads.

    Returns:
      [B*H, Ta, Tb] tensor whose row t carries the kernel of query position
      t placed at its offset along the key axis (zeros elsewhere).
    """
    # The leading dim depends on the runtime batch size.
    self.weight_expand_shape[0] = self.weight_padded_view_shape[0] = B * H

    weight_padded = F.pad(
        weight,
        self.weight_padding_list,
    ).view(self.weight_padded_view_shape)

    # Zero-initialized container for the expanded weights.
    weight_expanded = weight.new_zeros(
        self.weight_expand_shape, requires_grad=False)
    # Copy values through an as_strided view of the container; note that
    # when n_blk == 1 the block stride (step_y) is never exercised.
    weight_expanded.as_strided(
        self.weight_padded_view_shape,
        self.weight_expand_as_strided_steps).copy_(weight_padded)

    # Cut away the x padding and the y offset introduced by the search.
    weight_expanded = weight_expanded.narrow(2, self.pad_left, Tb)
    weight_expanded = weight_expanded.narrow(1, self.pad_up, Ta)
    return weight_expanded


if __name__ == '__main__':
  # Smoke test: sweep sequence-length pairs and odd kernel sizes, making
  # sure construction and a forward pass run without error.
  head = 9
  dim = head * 5
  dim2 = dim + head
  bs = 7

  # Loop-invariant: look the device up once, not once per configuration.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  for seq_len in range(2, 100 + 1):
    for seq_len2 in range(2, 100 + 1):
      # Odd base kernel sizes small enough that, after the even enlargement
      # done inside MyCDC, the kernel still fits in seq_len2.
      for ks in range(3, seq_len2 - floor(seq_len2 / seq_len) // 2 * 2 + 1,
                      2):
        q = torch.rand([seq_len, bs, dim]).to(device)
        k = torch.rand([seq_len2, bs, dim2]).to(device)

        c = MyCDC(
            ks,
            dim,
            dim2,
            seq_len,
            seq_len2,
            num_heads=head,
            weight_dropout=0.,
        ).to(device)
        c(q, k)
